{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import scipy\n",
    "from imblearn.under_sampling import RandomUnderSampler\n",
    "from models import AnonymousColorDetector\n",
    "from utils import read_labeled_img"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [],
   "source": [
    "train_from_existed = False # 是否从现有数据训练，如果是的话，那就从dataset_file训练，否则就用data_dir里头的数据\n",
    "data_dir = \"data/dataset\" # 数据集，文件夹下必须包含`img`和`label`两个文件夹，放置相同文件名的图片和label\n",
    "dataset_file = \"data/dataset/dataset_2022-07-20_10-04.mat\"\n",
    "\n",
    "# color_dict = {(0, 0, 255): \"yangeng\", (255, 0, 0): 'beijing',(0, 255, 0): \"zibian\"} # 颜色对应的类别\n",
    "# color_dict = {(0, 0, 255): \"yangeng\"}\n",
    "color_dict = {(255, 0, 0): 'beijing'}\n",
    "# color_dict = {(0, 255, 0): \"zibian\"}\n",
    "label_index = {\"yangeng\": 1, \"beijing\": 0, \"zibian\":2} # 类别对应的序号\n",
    "show_samples = False # 是否展示样本\n",
    "\n",
    "# 定义一些训练量\n",
    "threshold = 3  # 正样本周围多大范围内的还算是正样本\n",
    "node_num = 20  # 如果使用ELM作为分类器物，有多少的节点\n",
    "negative_sample_num = None  # None或者一个数字，对应生成的负样本数量"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 读取数据"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "outputs": [],
   "source": [
    "dataset = read_labeled_img(data_dir, color_dict=color_dict, is_ps_color_space=False)\n",
    "if show_samples:\n",
    "    from utils import lab_scatter\n",
    "    lab_scatter(dataset, class_max_num=30000, is_3d=True, is_ps_color_space=False)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 数据平衡化"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "if len(dataset) > 1:\n",
    "    rus = RandomUnderSampler(random_state=0)\n",
    "    x_list, y_list = np.concatenate([v for k, v in dataset.items()], axis=0).tolist(), \\\n",
    "                     np.concatenate([np.ones((v.shape[0],)) * label_index[k] for k, v in dataset.items()], axis=0).tolist()\n",
    "    x_resampled, y_resampled = rus.fit_resample(x_list, y_list)\n",
    "    dataset = {\"inside\": np.array(x_resampled)}"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 模型训练"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [],
   "source": [
    "# 对数据进行预处理\n",
    "x = np.concatenate([v for k, v in dataset.items()], axis=0)\n",
    "negative_sample_num = int(x.shape[0] * 1.2) if negative_sample_num is None else negative_sample_num\n",
    "model = AnonymousColorDetector()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "152147it [42:48, 59.24it/s]                                                                         \n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "              precision    recall  f1-score   support\n",
      "\n",
      "           0       1.00      1.00      1.00     45644\n",
      "           1       1.00      1.00      1.00     38037\n",
      "\n",
      "    accuracy                           1.00     83681\n",
      "   macro avg       1.00      1.00      1.00     83681\n",
      "weighted avg       1.00      1.00      1.00     83681\n",
      "\n"
     ]
    }
   ],
   "source": [
    "if train_from_existed:\n",
    "    data = scipy.io.loadmat(dataset_file)\n",
    "    x, y = data['x'], data['y'].ravel()\n",
    "    model.fit(x, y=y, is_generate_negative=False, model_selection='dt')\n",
    "else:\n",
    "    world_boundary = np.array([0, 0, 0, 255, 255, 255])\n",
    "    model.fit(x, world_boundary, threshold, negative_sample_size=negative_sample_num, train_size=0.7,\n",
    "          is_save_dataset=True, model_selection='dt')\n",
    "model.save()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## 模型的可视化"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "outputs": [],
   "source": [
    "model_path = \"dt_2022-07-21_17-19.model\"\n",
    "dataset_path = \"dataset_2022-07-21_17-19.mat\""
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "outputs": [
    {
     "ename": "FileNotFoundError",
     "evalue": "[Errno 2] No such file or directory: 'dataset_2022-07-21_17-19.mat'",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mFileNotFoundError\u001B[0m                         Traceback (most recent call last)",
      "File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\scipy\\io\\matlab\\mio.py:39\u001B[0m, in \u001B[0;36m_open_file\u001B[1;34m(file_like, appendmat, mode)\u001B[0m\n\u001B[0;32m     38\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m---> 39\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mfile_like\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m, \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[0;32m     40\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mIOError\u001B[39;00m \u001B[38;5;28;01mas\u001B[39;00m e:\n\u001B[0;32m     41\u001B[0m     \u001B[38;5;66;03m# Probably \"not found\"\u001B[39;00m\n",
      "\u001B[1;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'dataset_2022-07-21_17-19.mat'",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001B[1;31mFileNotFoundError\u001B[0m                         Traceback (most recent call last)",
      "Input \u001B[1;32mIn [22]\u001B[0m, in \u001B[0;36m<cell line: 1>\u001B[1;34m()\u001B[0m\n\u001B[1;32m----> 1\u001B[0m data \u001B[38;5;241m=\u001B[39m \u001B[43mscipy\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mio\u001B[49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mloadmat\u001B[49m\u001B[43m(\u001B[49m\u001B[43mdataset_path\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m      2\u001B[0m ground_truth \u001B[38;5;241m=\u001B[39m data[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mx\u001B[39m\u001B[38;5;124m\"\u001B[39m][data[\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124my\u001B[39m\u001B[38;5;124m\"\u001B[39m]\u001B[38;5;241m.\u001B[39mravel()\u001B[38;5;241m==\u001B[39m\u001B[38;5;241m1\u001B[39m]\n",
      "File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\scipy\\io\\matlab\\mio.py:224\u001B[0m, in \u001B[0;36mloadmat\u001B[1;34m(file_name, mdict, appendmat, **kwargs)\u001B[0m\n\u001B[0;32m     87\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m     88\u001B[0m \u001B[38;5;124;03mLoad MATLAB file.\u001B[39;00m\n\u001B[0;32m     89\u001B[0m \n\u001B[1;32m   (...)\u001B[0m\n\u001B[0;32m    221\u001B[0m \u001B[38;5;124;03m    3.14159265+3.14159265j])\u001B[39;00m\n\u001B[0;32m    222\u001B[0m \u001B[38;5;124;03m\"\"\"\u001B[39;00m\n\u001B[0;32m    223\u001B[0m variable_names \u001B[38;5;241m=\u001B[39m kwargs\u001B[38;5;241m.\u001B[39mpop(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mvariable_names\u001B[39m\u001B[38;5;124m'\u001B[39m, \u001B[38;5;28;01mNone\u001B[39;00m)\n\u001B[1;32m--> 224\u001B[0m \u001B[38;5;28;01mwith\u001B[39;00m _open_file_context(file_name, appendmat) \u001B[38;5;28;01mas\u001B[39;00m f:\n\u001B[0;32m    225\u001B[0m     MR, _ \u001B[38;5;241m=\u001B[39m mat_reader_factory(f, \u001B[38;5;241m*\u001B[39m\u001B[38;5;241m*\u001B[39mkwargs)\n\u001B[0;32m    226\u001B[0m     matfile_dict \u001B[38;5;241m=\u001B[39m MR\u001B[38;5;241m.\u001B[39mget_variables(variable_names)\n",
      "File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\contextlib.py:135\u001B[0m, in \u001B[0;36m_GeneratorContextManager.__enter__\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m    133\u001B[0m \u001B[38;5;28;01mdel\u001B[39;00m \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39margs, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mkwds, \u001B[38;5;28mself\u001B[39m\u001B[38;5;241m.\u001B[39mfunc\n\u001B[0;32m    134\u001B[0m \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[1;32m--> 135\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mnext\u001B[39;49m\u001B[43m(\u001B[49m\u001B[38;5;28;43mself\u001B[39;49m\u001B[38;5;241;43m.\u001B[39;49m\u001B[43mgen\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m    136\u001B[0m \u001B[38;5;28;01mexcept\u001B[39;00m \u001B[38;5;167;01mStopIteration\u001B[39;00m:\n\u001B[0;32m    137\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mRuntimeError\u001B[39;00m(\u001B[38;5;124m\"\u001B[39m\u001B[38;5;124mgenerator didn\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mt yield\u001B[39m\u001B[38;5;124m\"\u001B[39m) \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;28mNone\u001B[39m\n",
      "File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\scipy\\io\\matlab\\mio.py:17\u001B[0m, in \u001B[0;36m_open_file_context\u001B[1;34m(file_like, appendmat, mode)\u001B[0m\n\u001B[0;32m     15\u001B[0m \u001B[38;5;129m@contextmanager\u001B[39m\n\u001B[0;32m     16\u001B[0m \u001B[38;5;28;01mdef\u001B[39;00m \u001B[38;5;21m_open_file_context\u001B[39m(file_like, appendmat, mode\u001B[38;5;241m=\u001B[39m\u001B[38;5;124m'\u001B[39m\u001B[38;5;124mrb\u001B[39m\u001B[38;5;124m'\u001B[39m):\n\u001B[1;32m---> 17\u001B[0m     f, opened \u001B[38;5;241m=\u001B[39m \u001B[43m_open_file\u001B[49m\u001B[43m(\u001B[49m\u001B[43mfile_like\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mappendmat\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m\n\u001B[0;32m     18\u001B[0m     \u001B[38;5;28;01mtry\u001B[39;00m:\n\u001B[0;32m     19\u001B[0m         \u001B[38;5;28;01myield\u001B[39;00m f\n",
      "File \u001B[1;32m~\\miniconda3\\envs\\deepo\\lib\\site-packages\\scipy\\io\\matlab\\mio.py:45\u001B[0m, in \u001B[0;36m_open_file\u001B[1;34m(file_like, appendmat, mode)\u001B[0m\n\u001B[0;32m     43\u001B[0m     \u001B[38;5;28;01mif\u001B[39;00m appendmat \u001B[38;5;129;01mand\u001B[39;00m \u001B[38;5;129;01mnot\u001B[39;00m file_like\u001B[38;5;241m.\u001B[39mendswith(\u001B[38;5;124m'\u001B[39m\u001B[38;5;124m.mat\u001B[39m\u001B[38;5;124m'\u001B[39m):\n\u001B[0;32m     44\u001B[0m         file_like \u001B[38;5;241m+\u001B[39m\u001B[38;5;241m=\u001B[39m \u001B[38;5;124m'\u001B[39m\u001B[38;5;124m.mat\u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[1;32m---> 45\u001B[0m     \u001B[38;5;28;01mreturn\u001B[39;00m \u001B[38;5;28;43mopen\u001B[39;49m\u001B[43m(\u001B[49m\u001B[43mfile_like\u001B[49m\u001B[43m,\u001B[49m\u001B[43m \u001B[49m\u001B[43mmode\u001B[49m\u001B[43m)\u001B[49m, \u001B[38;5;28;01mTrue\u001B[39;00m\n\u001B[0;32m     46\u001B[0m \u001B[38;5;28;01melse\u001B[39;00m:\n\u001B[0;32m     47\u001B[0m     \u001B[38;5;28;01mraise\u001B[39;00m \u001B[38;5;167;01mIOError\u001B[39;00m(\n\u001B[0;32m     48\u001B[0m         \u001B[38;5;124m'\u001B[39m\u001B[38;5;124mReader needs file name or open file-like object\u001B[39m\u001B[38;5;124m'\u001B[39m\n\u001B[0;32m     49\u001B[0m     ) \u001B[38;5;28;01mfrom\u001B[39;00m \u001B[38;5;21;01me\u001B[39;00m\n",
      "\u001B[1;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'dataset_2022-07-21_17-19.mat'"
     ]
    }
   ],
   "source": [
    "data = scipy.io.loadmat(dataset_path)\n",
    "ground_truth = data[\"x\"][data[\"y\"].ravel()==1]"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "%matplotlib notebook\n",
    "model = AnonymousColorDetector(model_path)\n",
    "model.visualize(world_boundary = np.array([0, 0, 0, 255, 255, 255]),sample_size=5000,ground_truth=ground_truth)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}